{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "612b6a05",
   "metadata": {},
   "source": [
    "# Distributed training of an XGBoost model\n",
    "\n",
    "\n",
    "<div align=\"left\">\n",
    "<a target=\"_blank\" href=\"https://console.anyscale.com/\"><img src=\"https://img.shields.io/badge/🚀 Run_on-Anyscale-9hf\"></a>&nbsp;\n",
    "<a href=\"https://github.com/anyscale/e2e-xgboost\" role=\"button\"><img src=\"https://img.shields.io/static/v1?label=&amp;message=View%20On%20GitHub&amp;color=586069&amp;logo=github&amp;labelColor=2f363d\"></a>&nbsp;\n",
    "</div>\n",
    "\n",
    "This tutorial executes a distributed training workload that connects the following steps with heterogeneous compute requirements:\n",
    "- Preprocessing the dataset with Ray Data\n",
    "- Distributed training of an XGBoost model with Ray Train\n",
    "- Saving model artifacts to a model registry with MLflow\n",
    "\n",
    "**Note**: This tutorial doesn't including tuning of the model. See [Ray Tune](https://docs.ray.io/en/latest/tune/index.html) for experiment execution and hyperparameter tuning.\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/anyscale/e2e-xgboost/refs/heads/main/images/distributed_training.png\" width=800>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f778369e",
   "metadata": {},
   "source": [
    "## Dependencies\n",
    "\n",
    "To install the dependencies, run the following:\n",
    "\n",
    "```bash\n",
    "pip install -r requirements.txt\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab9d9875",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Import the necessary modules:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f5493d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cba9f2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable importing from dist_xgboost module.\n",
    "import os\n",
    "import sys\n",
    "\n",
    "sys.path.append(os.path.abspath(\"..\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23ddcfbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable Ray Train v2. This is the default in an upcoming release.\n",
    "os.environ[\"RAY_TRAIN_V2_ENABLED\"] = \"1\"\n",
    "# Now it's safe to import from ray.train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15cfb416",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ray\n",
    "\n",
    "from dist_xgboost.constants import storage_path, preprocessor_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05f79e20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make Ray data less verbose.\n",
    "ray.data.DataContext.get_current().enable_progress_bars = False\n",
    "ray.data.DataContext.get_current().print_on_execution_start = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ad88db8",
   "metadata": {},
   "source": [
    "## Dataset preparation\n",
    "\n",
    "This example uses the [Breast Cancer Wisconsin (Diagnostic)](https://archive.ics.uci.edu/dataset/17/breast+cancer+wisconsin+diagnostic) dataset, which contains features computed from digitized images of breast mass cell nuclei.\n",
    "\n",
    "Split the data into:\n",
    "- 70% for training\n",
    "- 15% for validation\n",
    "- 15% for testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1036655e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ray.data import Dataset\n",
    "\n",
    "\n",
    "def prepare_data() -> tuple[Dataset, Dataset, Dataset]:\n",
    "    \"\"\"Load and split the dataset into train, validation, and test sets.\"\"\"\n",
    "    # Load the dataset from S3.\n",
    "    dataset = ray.data.read_csv(\"s3://anonymous@air-example-data/breast_cancer.csv\")\n",
    "    seed = 42\n",
    "\n",
    "    # Split 70% for training.\n",
    "    train_dataset, rest = dataset.train_test_split(test_size=0.3, shuffle=True, seed=seed)\n",
    "    # Split the remaining 30% into 15% validation and 15% testing.\n",
    "    valid_dataset, test_dataset = rest.train_test_split(test_size=0.5, shuffle=True, seed=seed)\n",
    "    return train_dataset, valid_dataset, test_dataset"
   ]
  },
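  {
   "cell_type": "markdown",
   "id": "3f7a9c21",
   "metadata": {},
   "source": [
    "The two-stage split in `prepare_data()` yields 70/15/15 because the second `test_size` applies only to the 30% left over from the first split. The following is a minimal sketch of that arithmetic in plain Python; `split_fractions` is a hypothetical helper for illustration and isn't part of Ray Data:\n",
    "\n",
    "```python\n",
    "def split_fractions(first_test_size: float, second_test_size: float) -> tuple[float, float, float]:\n",
    "    # Fraction kept for training after the first split.\n",
    "    train = 1.0 - first_test_size\n",
    "    # The held-out remainder is split again into validation and test.\n",
    "    valid = first_test_size * (1.0 - second_test_size)\n",
    "    test = first_test_size * second_test_size\n",
    "    return train, valid, test\n",
    "\n",
    "\n",
    "train_frac, valid_frac, test_frac = split_fractions(0.3, 0.5)\n",
    "print(f'train={train_frac:.2f}, valid={valid_frac:.2f}, test={test_frac:.2f}')\n",
    "```\n",
    "\n",
    "Changing the first `test_size` shifts all three fractions, while changing the second only rebalances validation against test."
   ]
  },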
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06b0f220",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-16 21:01:53,956\tINFO worker.py:1660 -- Connecting to existing Ray cluster at address: 10.0.23.200:6379...\n",
      "2025-04-16 21:01:53,966\tINFO worker.py:1843 -- Connected to Ray cluster. View the dashboard at \u001b[1m\u001b[32mhttps://session-1kebpylz8tcjd34p4sv2h1f9tg.i.anyscaleuserdata.com \u001b[39m\u001b[22m\n",
      "2025-04-16 21:01:53,972\tINFO packaging.py:575 -- Creating a file package for local module '/home/ray/default/e2e-xgboost/dist_xgboost'.\n",
      "2025-04-16 21:01:53,975\tINFO packaging.py:367 -- Pushing file package 'gcs://_ray_pkg_aa0e5fd0ec6b8edc.zip' (0.02MiB) to Ray cluster...\n",
      "2025-04-16 21:01:53,976\tINFO packaging.py:380 -- Successfully pushed file package 'gcs://_ray_pkg_aa0e5fd0ec6b8edc.zip'.\n",
      "2025-04-16 21:01:53,977\tINFO packaging.py:367 -- Pushing file package 'gcs://_ray_pkg_38ec1ca756a7ccf23a0c590d356f26fc87860d8a.zip' (0.07MiB) to Ray cluster...\n",
      "2025-04-16 21:01:53,978\tINFO packaging.py:380 -- Successfully pushed file package 'gcs://_ray_pkg_38ec1ca756a7ccf23a0c590d356f26fc87860d8a.zip'.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[36m(autoscaler +11s)\u001b[0m Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.\n",
      "\u001b[36m(autoscaler +11s)\u001b[0m [autoscaler] [8CPU-32GB] Attempting to add 1 node(s) to the cluster (increasing from 0 to 1).\n",
      "\u001b[36m(autoscaler +11s)\u001b[0m [autoscaler] [8CPU-32GB] Launched 1 instances.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-16 21:03:12,957\tINFO dataset.py:2809 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'mean radius': 19.16,\n",
       "  'mean texture': 26.6,\n",
       "  'mean perimeter': 126.2,\n",
       "  'mean area': 1138.0,\n",
       "  'mean smoothness': 0.102,\n",
       "  'mean compactness': 0.1453,\n",
       "  'mean concavity': 0.1921,\n",
       "  'mean concave points': 0.09664,\n",
       "  'mean symmetry': 0.1902,\n",
       "  'mean fractal dimension': 0.0622,\n",
       "  'radius error': 0.6361,\n",
       "  'texture error': 1.001,\n",
       "  'perimeter error': 4.321,\n",
       "  'area error': 69.65,\n",
       "  'smoothness error': 0.007392,\n",
       "  'compactness error': 0.02449,\n",
       "  'concavity error': 0.03988,\n",
       "  'concave points error': 0.01293,\n",
       "  'symmetry error': 0.01435,\n",
       "  'fractal dimension error': 0.003446,\n",
       "  'worst radius': 23.72,\n",
       "  'worst texture': 35.9,\n",
       "  'worst perimeter': 159.8,\n",
       "  'worst area': 1724.0,\n",
       "  'worst smoothness': 0.1782,\n",
       "  'worst compactness': 0.3841,\n",
       "  'worst concavity': 0.5754,\n",
       "  'worst concave points': 0.1872,\n",
       "  'worst symmetry': 0.3258,\n",
       "  'worst fractal dimension': 0.0972,\n",
       "  'target': 0}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load and split the dataset.\n",
    "train_dataset, valid_dataset, _test_dataset = prepare_data()\n",
    "train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b65de1dd",
   "metadata": {},
   "source": [
    "Look at the output to see that the dataset contains features characterizing cell nuclei in breast mass, such as radius, texture, perimeter, area, smoothness, compactness, concavity, symmetry, and more."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56e67eb1",
   "metadata": {},
   "source": [
    "## Data preprocessing\n",
    "\n",
    "Notice that the features have different magnitudes and ranges. While tree-based models like XGBoost aren't as sensitive to these differences, feature scaling can still improve numerical stability in some cases.\n",
    "\n",
    "Ray Data has built-in preprocessors that simplify common feature preprocessing tasks, especially for tabular data. You can integrate these preprocessors with Ray Datasets, to preprocess data in a fault-tolerant and distributed way.\n",
    "\n",
    "This example uses Ray's built-in `StandardScaler` to zero-center and normalize the features:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7256185",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ray.data.preprocessors import StandardScaler\n",
    "\n",
    "\n",
    "def train_preprocessor(train_dataset: ray.data.Dataset) -> StandardScaler:\n",
    "    # Pick some dataset columns to scale.\n",
    "    columns_to_scale = [c for c in train_dataset.columns() if c != \"target\"]\n",
    "\n",
    "    # Initialize the preprocessor.\n",
    "    preprocessor = StandardScaler(columns=columns_to_scale)\n",
    "    # Train the preprocessor on the training set.\n",
    "    preprocessor.fit(train_dataset)\n",
    "\n",
    "    return preprocessor\n",
    "\n",
    "\n",
    "preprocessor = train_preprocessor(train_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19daa596",
   "metadata": {},
   "source": [
    "Now that you've fit the preprocessor, save it to a file. Register this artifact later in MLflow so you can reuse it in downstream pipelines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2688e721",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "with open(preprocessor_path, \"wb\") as f:\n",
    "    pickle.dump(preprocessor, f)"
   ]
  },
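  {
   "cell_type": "markdown",
   "id": "8b2d4e60",
   "metadata": {},
   "source": [
    "A downstream pipeline can load the saved preprocessor back with `pickle.load`. The following is a minimal round-trip sketch, assuming a plain dictionary as a stand-in for the fitted `StandardScaler` and a temporary file as a stand-in for `preprocessor_path`; the statistics in it are made-up values:\n",
    "\n",
    "```python\n",
    "import os\n",
    "import pickle\n",
    "import tempfile\n",
    "\n",
    "# Hypothetical stand-in for the fitted preprocessor; the values are made up.\n",
    "fitted = {'columns': ['mean radius'], 'mean': 14.1, 'std': 3.5}\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "    # Stand-in for `preprocessor_path`.\n",
    "    path = os.path.join(tmp_dir, 'preprocessor.pkl')\n",
    "\n",
    "    # Save the object, as the tutorial does.\n",
    "    with open(path, 'wb') as f:\n",
    "        pickle.dump(fitted, f)\n",
    "\n",
    "    # Load it back, as a downstream pipeline would.\n",
    "    with open(path, 'rb') as f:\n",
    "        restored = pickle.load(f)\n",
    "\n",
    "print(restored == fitted)\n",
    "```\n",
    "\n",
    "Only unpickle files from a source you trust, because `pickle.load` can execute arbitrary code."
   ]
  },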
  {
   "cell_type": "markdown",
   "id": "c4ff2165",
   "metadata": {},
   "source": [
    "Next, transform the datasets using the fitted preprocessor. Note that the `transform()` operation is lazy. Ray Data won't apply it to the data until the Ray Train workers require the data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "230223b3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'mean radius': 1.3883915483364895,\n",
       "  'mean texture': 1.6582900738074817,\n",
       "  'mean perimeter': 1.3686612092802328,\n",
       "  'mean area': 1.3271629358408426,\n",
       "  'mean smoothness': 0.3726369329455741,\n",
       "  'mean compactness': 0.7709391453349583,\n",
       "  'mean concavity': 1.2156484038771678,\n",
       "  'mean concave points': 1.1909841981870102,\n",
       "  'mean symmetry': 0.33295997290846857,\n",
       "  'mean fractal dimension': -0.07207903519571106,\n",
       "  'radius error': 0.8074600624242092,\n",
       "  'texture error': -0.3842391069975234,\n",
       "  'perimeter error': 0.6925593054563496,\n",
       "  'area error': 0.5852832746827147,\n",
       "  'smoothness error': 0.13331319500721583,\n",
       "  'compactness error': -0.03934175265392654,\n",
       "  'concavity error': 0.22009334597724586,\n",
       "  'concave points error': 0.16570998568362863,\n",
       "  'symmetry error': -0.7220900323187186,\n",
       "  'fractal dimension error': -0.13670701917436776,\n",
       "  'worst radius': 1.5076654048043645,\n",
       "  'worst texture': 1.6169142713721316,\n",
       "  'worst perimeter': 1.5267353447826646,\n",
       "  'worst area': 1.4332237868207693,\n",
       "  'worst smoothness': 1.993402211865443,\n",
       "  'worst compactness': 0.8646836438651355,\n",
       "  'worst concavity': 1.3882655471454963,\n",
       "  'worst concave points': 1.0898377217385602,\n",
       "  'worst symmetry': 0.5707716568830431,\n",
       "  'worst fractal dimension': 0.7444861349012516,\n",
       "  'target': 0}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset = preprocessor.transform(train_dataset)\n",
    "valid_dataset = preprocessor.transform(valid_dataset)\n",
    "train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2ab598f",
   "metadata": {},
   "source": [
    "Using `take()`, to see that Ray Data zero-centered and rescaled the values to be roughly between -1 and 1."
   ]
  },
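  {
   "cell_type": "markdown",
   "id": "c5e17a94",
   "metadata": {},
   "source": [
    "Standard scaling subtracts each column's mean and divides by its standard deviation, so a scaled column has mean 0 and standard deviation 1. Individual values can still fall outside [-1, 1], as `worst smoothness` does in the preceding output. The following is a minimal sketch in plain Python, assuming the population standard deviation; `standard_scale` is a hypothetical helper, not Ray's implementation:\n",
    "\n",
    "```python\n",
    "import statistics\n",
    "\n",
    "\n",
    "def standard_scale(values: list[float]) -> list[float]:\n",
    "    # Center on the mean, then divide by the population standard deviation.\n",
    "    mean = statistics.fmean(values)\n",
    "    std = statistics.pstdev(values)\n",
    "    return [(v - mean) / std for v in values]\n",
    "\n",
    "\n",
    "scaled = standard_scale([10.0, 12.0, 14.0, 16.0, 18.0])\n",
    "print([round(v, 3) for v in scaled])\n",
    "```\n",
    "\n",
    "The sketch handles one column at a time and assumes the column isn't constant, because a zero standard deviation would cause a division by zero."
   ]
  },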
  {
   "cell_type": "markdown",
   "id": "128cb831",
   "metadata": {},
   "source": [
    "> **Data processing note**:  \n",
    "> For more advanced data loading and preprocessing techniques, see the [comprehensive guide](https://docs.ray.io/en/latest/train/user-guides/data-loading-preprocessing.html). Ray Data also supports performant joins, filters, aggregations, and other operations for more structured data processing, if required."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76b534fa",
   "metadata": {},
   "source": [
    "## Model training with XGBoost\n",
    "\n",
    "### Checkpointing configuration\n",
    "\n",
    "Checkpointing is a powerful feature that enables you to resume training from the last checkpoint in case of interruptions. Checkpointing is particularly useful for long-running training sessions.\n",
    "\n",
    "[`XGBoostTrainer`](https://docs.ray.io/en/latest/train/api/doc/ray.train.xgboost.XGBoostTrainer.html) implements checkpointing out of the box. Configure [`CheckpointConfig`](https://docs.ray.io/en/latest/train/api/doc/ray.train.CheckpointConfig.html) to set the checkpointing frequency."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9787bb14",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ray.train import CheckpointConfig, Result, RunConfig, ScalingConfig\n",
    "\n",
    "# Configure checkpointing to save progress during training.\n",
    "run_config = RunConfig(\n",
    "    checkpoint_config=CheckpointConfig(\n",
    "        # Checkpoint every 10 iterations.\n",
    "        checkpoint_frequency=10,\n",
    "        # Only keep the latest checkpoint.\n",
    "        num_to_keep=1,\n",
    "    ),\n",
    "    ## For multi-node clusters, configure storage that's accessible\n",
    "    ## across all worker nodes with `storage_path=\"s3://...\"`.\n",
    "    storage_path=storage_path,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "feaee233",
   "metadata": {},
   "source": [
    "> **Note**: Once you enable checkpointing, you can follow [this guide](https://docs.ray.io/en/latest/train/user-guides/fault-tolerance.html) to enable fault tolerance."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9887b8e",
   "metadata": {},
   "source": [
    "### Training with XGBoost\n",
    "\n",
    "Pass training parameters as a dictionary, similar to the original [`xgboost.train()`](https://xgboost.readthedocs.io/en/stable/parameter.html) function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a173cc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import xgboost\n",
    "from ray.train.xgboost import RayTrainReportCallback, XGBoostTrainer\n",
    "\n",
    "NUM_WORKERS = 4\n",
    "USE_GPU = True\n",
    "\n",
    "\n",
    "def train_fn_per_worker(config: dict):\n",
    "    \"\"\"Training function that runs on each worker.\n",
    "\n",
    "    This function:\n",
    "    1. Gets the dataset shard for this worker\n",
    "    2. Converts to pandas for XGBoost\n",
    "    3. Separates features and labels\n",
    "    4. Creates DMatrix objects\n",
    "    5. Trains the model using distributed communication\n",
    "    \"\"\"\n",
    "    # Get this worker's dataset shard.\n",
    "    train_ds, val_ds = (\n",
    "        ray.train.get_dataset_shard(\"train\"),\n",
    "        ray.train.get_dataset_shard(\"validation\"),\n",
    "    )\n",
    "\n",
    "    # Materialize the data and convert to pandas.\n",
    "    train_ds = train_ds.materialize().to_pandas()\n",
    "    val_ds = val_ds.materialize().to_pandas()\n",
    "\n",
    "    # Separate the labels from the features.\n",
    "    train_X, train_y = train_ds.drop(\"target\", axis=1), train_ds[\"target\"]\n",
    "    eval_X, eval_y = val_ds.drop(\"target\", axis=1), val_ds[\"target\"]\n",
    "\n",
    "    # Convert the data into DMatrix format for XGBoost.\n",
    "    dtrain = xgboost.DMatrix(train_X, label=train_y)\n",
    "    deval = xgboost.DMatrix(eval_X, label=eval_y)\n",
    "\n",
    "    # Do distributed data-parallel training.\n",
    "    # Ray Train sets up the necessary coordinator processes and\n",
    "    # environment variables for workers to communicate with each other.\n",
    "    _booster = xgboost.train(\n",
    "        config[\"xgboost_params\"],\n",
    "        dtrain=dtrain,\n",
    "        evals=[(dtrain, \"train\"), (deval, \"validation\")],\n",
    "        num_boost_round=10,\n",
    "        # Handles metric logging and checkpointing.\n",
    "        callbacks=[RayTrainReportCallback()],\n",
    "    )\n",
    "\n",
    "\n",
    "# Parameters for the XGBoost model.\n",
    "model_config = {\n",
    "    \"xgboost_params\": {\n",
    "        \"objective\": \"binary:logistic\",\n",
    "        \"eval_metric\": [\"logloss\", \"error\"],\n",
    "    }\n",
    "}\n",
    "\n",
    "trainer = XGBoostTrainer(\n",
    "    train_fn_per_worker,\n",
    "    train_loop_config=model_config,\n",
    "    # Register the data subsets.\n",
    "    datasets={\"train\": train_dataset, \"validation\": valid_dataset},\n",
    "    # See \"Scaling strategies\" for more details.\n",
    "    scaling_config=ScalingConfig(\n",
    "        # Number of workers for data parallelism.\n",
    "        num_workers=NUM_WORKERS,\n",
    "        # Set to True to use GPU acceleration.\n",
    "        use_gpu=USE_GPU,\n",
    "    ),\n",
    "    run_config=run_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a70c0a68",
   "metadata": {},
   "source": [
    "> **Ray Train benefits**:\n",
    "> \n",
    "> - **Multi-node orchestration**: Automatically handles multi-node, multi-GPU setup without manual SSH or hostfile configurations\n",
    "> - **Built-in fault tolerance**: Supports automatic retry of failed workers and can continue from the last checkpoint\n",
    "> - **Flexible training strategies**: Supports various parallelism strategies beyond just data parallel training\n",
    "> - **Heterogeneous cluster support**: Define per-worker resource requirements and run on mixed hardware\n",
    "> \n",
    "> Ray Train integrates with popular frameworks like PyTorch, TensorFlow, XGBoost, and more. For enterprise needs, [RayTurbo Train](https://docs.anyscale.com/rayturbo/rayturbo-train) offers additional features like elastic training, advanced monitoring, and performance optimization.\n",
    ">\n",
    "> <img src=\"https://raw.githubusercontent.com/anyscale/e2e-xgboost/refs/heads/main/images/train_integrations.png\" width=500>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30fe32cf",
   "metadata": {},
   "source": [
    "Next, train the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "005f33bb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(TrainController pid=19121)\u001b[0m Attempting to start training worker group of size 5 with the following resources: [{'GPU': 1}] * 5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[36m(autoscaler +1m31s)\u001b[0m [autoscaler] [8xA10G:192CPU-768GB] Attempting to add 1 node(s) to the cluster (increasing from 0 to 1).\n",
      "\u001b[36m(autoscaler +1m31s)\u001b[0m [autoscaler] Launching instances failed: NewInstances[g5.48xlarge;num:1;all:false]: could not launch any instances: api error Unsupported: Instance type g5.48xlarge is not supported in zone us-west-2d.\n",
      "\u001b[36m(autoscaler +1m31s)\u001b[0m [autoscaler] [1xA10G:16CPU-64GB] Attempting to add 5 node(s) to the cluster (increasing from 0 to 5).\n",
      "\u001b[36m(autoscaler +1m31s)\u001b[0m [autoscaler] Launching instances failed: NewInstances[g5.4xlarge;num:5;all:false]: could not launch any instances: api error Unsupported: Instance type g5.4xlarge is not supported in zone us-west-2d.\n",
      "\u001b[36m(autoscaler +1m31s)\u001b[0m [autoscaler] [1xA10G:32CPU-128GB] Attempting to add 5 node(s) to the cluster (increasing from 0 to 5).\n",
      "\u001b[36m(autoscaler +1m36s)\u001b[0m [autoscaler] Launching instances failed: NewInstances[g5.8xlarge;num:5;all:false]: could not launch any instances: api error Unsupported: Instance type g5.8xlarge is not supported in zone us-west-2d.\n",
      "\u001b[36m(autoscaler +1m36s)\u001b[0m [autoscaler] [1xL4:4CPU-16GB] Attempting to add 1 node(s) to the cluster (increasing from 0 to 1).\n",
      "\u001b[36m(autoscaler +1m36s)\u001b[0m [autoscaler] [4xL4:48CPU-192GB] Attempting to add 1 node(s) to the cluster (increasing from 0 to 1).\n",
      "\u001b[36m(autoscaler +1m36s)\u001b[0m [autoscaler] [4xL4:48CPU-192GB] Launched 1 instances.\n",
      "\u001b[36m(autoscaler +1m36s)\u001b[0m [autoscaler] [1xL4:4CPU-16GB] Launched 1 instances.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(TrainController pid=19121)\u001b[0m Retrying the launch of the training worker group. The previous launch attempt encountered the following failure:\n",
      "\u001b[36m(TrainController pid=19121)\u001b[0m The worker group startup timed out after 30.0 seconds waiting for 5 workers. Potential causes include: (1) temporary insufficient cluster resources while waiting for autoscaling (ignore this warning in this case), (2) infeasible resource request where the provided `ScalingConfig` cannot be satisfied), and (3) transient network issues. Set the RAY_TRAIN_WORKER_GROUP_START_TIMEOUT_S environment variable to increase the timeout.\n",
      "\u001b[36m(TrainController pid=19121)\u001b[0m Attempting to start training worker group of size 5 with the following resources: [{'GPU': 1}] * 5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[36m(autoscaler +2m21s)\u001b[0m [autoscaler] Cluster upscaled to {12 CPU, 1 GPU}.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(TrainController pid=19121)\u001b[0m Retrying the launch of the training worker group. The previous launch attempt encountered the following failure:\n",
      "\u001b[36m(TrainController pid=19121)\u001b[0m The worker group startup timed out after 30.0 seconds waiting for 5 workers. Potential causes include: (1) temporary insufficient cluster resources while waiting for autoscaling (ignore this warning in this case), (2) infeasible resource request where the provided `ScalingConfig` cannot be satisfied), and (3) transient network issues. Set the RAY_TRAIN_WORKER_GROUP_START_TIMEOUT_S environment variable to increase the timeout.\n",
      "\u001b[36m(TrainController pid=19121)\u001b[0m Attempting to start training worker group of size 5 with the following resources: [{'GPU': 1}] * 5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[36m(autoscaler +2m31s)\u001b[0m [autoscaler] Cluster upscaled to {60 CPU, 5 GPU}.\n",
      "\u001b[33m(raylet)\u001b[0m WARNING: 4 PYTHON worker processes have been started on node: dc30e171b93f61245644ba4d0147f8b27f64e9e1eaf34d1bb63c9c99 with address: 10.0.23.200. This could be a result of using a large number of actors, or due to tasks blocked in ray.get() calls (see https://github.com/ray-project/ray/issues/3644 for some discussion of workarounds).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(RayTrainWorker pid=3285, ip=10.0.223.105)\u001b[0m [21:04:38] Task [xgboost.ray-rank=00000002]:fa43387771ebd5738fd50b6303000000 got rank 2\n",
      "\u001b[36m(TrainController pid=19121)\u001b[0m [21:04:42] [0]\ttrain-logloss:0.44514\ttrain-error:0.04051\tvalidation-logloss:0.43997\tvalidation-error:0.04706\n",
      "\u001b[36m(TrainController pid=19121)\u001b[0m [21:04:44] [1]\ttrain-logloss:0.31649\ttrain-error:0.01772\tvalidation-logloss:0.31594\tvalidation-error:0.04706\n",
      "\u001b[36m(RayTrainWorker pid=2313, ip=10.0.223.33)\u001b[0m [21:04:38] Task [xgboost.ray-rank=00000004]:a6ed8004330660f5a370531f03000000 got rank 4\u001b[32m [repeated 4x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)\u001b[0m\n",
      "\u001b[36m(TrainController pid=19121)\u001b[0m [21:04:46] [2]\ttrain-logloss:0.23701\ttrain-error:0.01266\tvalidation-logloss:0.24072\tvalidation-error:0.02353\n",
      "\u001b[36m(TrainController pid=19121)\u001b[0m [21:04:48] [3]\ttrain-logloss:0.18165\ttrain-error:0.00759\tvalidation-logloss:0.19038\tvalidation-error:0.01176\n",
      "\u001b[36m(TrainController pid=19121)\u001b[0m [21:04:50] [4]\ttrain-logloss:0.14258\ttrain-error:0.00759\tvalidation-logloss:0.14917\tvalidation-error:0.01176\n",
      "\u001b[36m(TrainController pid=19121)\u001b[0m [21:04:52] [5]\ttrain-logloss:0.11360\ttrain-error:0.00759\tvalidation-logloss:0.12113\tvalidation-error:0.01176\n",
      "\u001b[36m(TrainController pid=19121)\u001b[0m [21:04:54] [6]\ttrain-logloss:0.09207\ttrain-error:0.00759\tvalidation-logloss:0.10018\tvalidation-error:0.01176\n",
      "\u001b[36m(TrainController pid=19121)\u001b[0m [21:04:56] [7]\ttrain-logloss:0.07616\ttrain-error:0.00506\tvalidation-logloss:0.08632\tvalidation-error:0.01176\n",
      "\u001b[36m(TrainController pid=19121)\u001b[0m [21:04:58] [8]\ttrain-logloss:0.06419\ttrain-error:0.00506\tvalidation-logloss:0.07705\tvalidation-error:0.01176\n",
      "\u001b[36m(TrainController pid=19121)\u001b[0m [21:05:00] [9]\ttrain-logloss:0.05463\ttrain-error:0.00506\tvalidation-logloss:0.06741\tvalidation-error:0.01176\n",
      "\u001b[36m(RayTrainWorker pid=3284, ip=10.0.223.105)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/user_storage/ray_train_run-2025-04-16_21-03-13/checkpoint_2025-04-16_21-05-00.160991)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Result(metrics=OrderedDict({'train-logloss': 0.05463397157248817, 'train-error': 0.00506329113924051, 'validation-logloss': 0.06741214815308066, 'validation-error': 0.01176470588235294}), checkpoint=Checkpoint(filesystem=local, path=/mnt/user_storage/ray_train_run-2025-04-16_21-03-13/checkpoint_2025-04-16_21-05-00.160991), error=None, path='/mnt/user_storage/ray_train_run-2025-04-16_21-03-13', metrics_dataframe=   train-logloss  train-error  validation-logloss  validation-error\n",
       "0       0.054634     0.005063            0.067412          0.011765, best_checkpoints=[(Checkpoint(filesystem=local, path=/mnt/user_storage/ray_train_run-2025-04-16_21-03-13/checkpoint_2025-04-16_21-05-00.160991), OrderedDict({'train-logloss': 0.05463397157248817, 'train-error': 0.00506329113924051, 'validation-logloss': 0.06741214815308066, 'validation-error': 0.01176470588235294}))], _storage_filesystem=<pyarrow._fs.LocalFileSystem object at 0x7ea450adb130>)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result: Result = trainer.fit()\n",
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "acf06ba2",
   "metadata": {},
   "source": [
    "At the beginning of the training job, Ray started requesting GPU nodes to satisfy the training job's requirement of five GPU workers.\n",
    "\n",
    "Ray Train returns a [`ray.train.Result`](https://docs.ray.io/en/latest/train/api/doc/ray.train.Result.html) object, which contains important properties such as metrics, checkpoint information, and error details:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "929c13bc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('train-logloss', 0.05463397157248817),\n",
       "             ('train-error', 0.00506329113924051),\n",
       "             ('validation-logloss', 0.06741214815308066),\n",
       "             ('validation-error', 0.01176470588235294)])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "metrics = result.metrics\n",
    "metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b18221b",
   "metadata": {},
   "source": [
    "The expected output are similar to the following:\n",
    "\n",
    "```python\n",
    "OrderedDict([('train-logloss', 0.05463397157248817),\n",
    "             ('train-error', 0.00506329113924051),\n",
    "             ('validation-logloss', 0.06741214815308066),\n",
    "             ('validation-error', 0.01176470588235294)])\n",
    "```\n",
    "\n",
    "See that the Ray Train logs metrics based on the values you configured in `eval_metric` and `evals`."
   ]
  },
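  {
   "cell_type": "markdown",
   "id": "d90b6f13",
   "metadata": {},
   "source": [
    "The metric keys combine each name from `evals` with each entry in `eval_metric`. The following sketch rebuilds the keys from the preceding output. It only illustrates the naming pattern seen in that output and doesn't call XGBoost:\n",
    "\n",
    "```python\n",
    "# Names given to the evaluation sets in `evals`.\n",
    "eval_names = ['train', 'validation']\n",
    "# Metrics listed under `eval_metric` in the XGBoost parameters.\n",
    "eval_metrics = ['logloss', 'error']\n",
    "\n",
    "# One key per (evaluation set, metric) pair, joined with a hyphen.\n",
    "metric_keys = [f'{name}-{metric}' for name in eval_names for metric in eval_metrics]\n",
    "print(metric_keys)\n",
    "```\n",
    "\n",
    "Knowing the pattern helps when you look up a specific value, such as `result.metrics['validation-logloss']`."
   ]
  },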
  {
   "cell_type": "markdown",
   "id": "7e15f51a",
   "metadata": {},
   "source": [
    "You can also reconstruct the trained model from the checkpoint directory:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87892b1f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<xgboost.core.Booster at 0x7ea4531beea0>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "booster = RayTrainReportCallback.get_model(result.checkpoint)\n",
    "booster"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2f0523a",
   "metadata": {},
   "source": [
    "## Model registry\n",
    "\n",
    "Now that you've trained the model, save it to a model registry for future use. As this is a distributed training workload, the model registry storage needs to be accessible from all workers in the cluster. This storage can be S3, NFS, or another network-attached solution. Anyscale simplifies this process by automatically creating and mounting [shared storage options](https://docs.anyscale.com/configuration/storage/#storage-shared-across-nodes) on every cluster node, ensuring that model artifacts are readable and writable across the distributed environment.\n",
    "\n",
    "The MLflow tracking server stores experiment metadata and model artifacts in the shared storage location, making them available for future model serving, evaluation, or retraining workflows. Ray also integrates with [other experiment trackers](https://docs.ray.io/en/latest/train/user-guides/experiment-tracking.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cba23e9b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025/04/16 21:07:07 INFO mlflow.tracking.fluent: Experiment with name 'breast_cancer_all_features' does not exist. Creating a new experiment.\n"
     ]
    }
   ],
   "source": [
    "import shutil\n",
    "from tempfile import TemporaryDirectory\n",
    "\n",
    "import mlflow\n",
    "\n",
    "from dist_xgboost.constants import (\n",
    "    experiment_name,\n",
    "    model_fname,\n",
    "    model_registry,\n",
    "    preprocessor_fname,\n",
    ")\n",
    "\n",
    "\n",
    "def clean_up_old_runs():\n",
    "    # Clean up old MLflow runs.\n",
    "    os.path.isdir(model_registry) and shutil.rmtree(model_registry)\n",
    "    # mlflow.delete_experiment(experiment_name)\n",
    "    os.makedirs(model_registry, exist_ok=True)\n",
    "\n",
    "\n",
    "def log_run_to_mlflow(model_config, result, preprocessor_path):\n",
    "    # Create a model registry in user storage.\n",
    "    mlflow.set_tracking_uri(f\"file:{model_registry}\")\n",
    "\n",
    "    # Create a new experiment and log metrics and artifacts.\n",
    "    mlflow.set_experiment(experiment_name)\n",
    "    with mlflow.start_run(description=\"xgboost breast cancer classifier on all features\"):\n",
    "        mlflow.log_params(model_config)\n",
    "        mlflow.log_metrics(result.metrics)\n",
    "\n",
    "        # Selectively log just the preprocessor and model weights.\n",
    "        with TemporaryDirectory() as tmp_dir:\n",
    "            shutil.copy(\n",
    "                os.path.join(result.checkpoint.path, model_fname),\n",
    "                os.path.join(tmp_dir, model_fname),\n",
    "            )\n",
    "            shutil.copy(\n",
    "                preprocessor_path,\n",
    "                os.path.join(tmp_dir, preprocessor_fname),\n",
    "            )\n",
    "\n",
    "            mlflow.log_artifacts(tmp_dir)\n",
    "\n",
    "\n",
    "clean_up_old_runs()\n",
    "log_run_to_mlflow(model_config, result, preprocessor_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "340529d0",
   "metadata": {},
   "source": [
    "Start the MLflow server to view the experiments:\n",
    "\n",
    "`mlflow server -h 0.0.0.0 -p 8080 --backend-store-uri {model_registry}`"
   ]
  },
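  {
   "cell_type": "markdown",
   "id": "7d3e9a41",
   "metadata": {},
   "source": [
    "As a small sketch, the hypothetical `mlflow_server_command` helper below assembles the `mlflow server` command above for a given registry path. The helper isn't part of MLflow or of this repository, and the example path is an assumption taken from the output later in this notebook; it may differ on your cluster.\n",
    "\n",
    "```python\n",
    "import shlex\n",
    "\n",
    "\n",
    "def mlflow_server_command(registry_path, host=\"0.0.0.0\", port=8080):\n",
    "    # Hypothetical helper: build the `mlflow server` command line and quote\n",
    "    # each argument so that paths containing spaces stay intact.\n",
    "    args = [\"mlflow\", \"server\", \"-h\", host, \"-p\", str(port)]\n",
    "    args += [\"--backend-store-uri\", registry_path]\n",
    "    return shlex.join(args)\n",
    "\n",
    "\n",
    "# Assumed registry path; substitute the value of `model_registry`.\n",
    "print(mlflow_server_command(\"/mnt/user_storage/mlflow\"))\n",
    "```\n",
    "\n",
    "You can paste the printed command into a terminal on the head node."
   ]
  },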
  {
   "cell_type": "markdown",
   "id": "2c529b2c",
   "metadata": {},
   "source": [
    "To view the dashboard, go to the **Overview tab** > **Open Ports** > `8080`.\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/anyscale/e2e-xgboost/refs/heads/main/images/mlflow.png\" width=685>\n",
    "\n",
    "You can also view the Ray Dashboard and Train workload dashboards:\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/anyscale/e2e-xgboost/refs/heads/main/images/train_metrics.png\" width=700>\n",
    "\n",
    "You can retrieve the best model from the registry:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d394cdf4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/mnt/user_storage/mlflow/290203875164933232/eb2666ca6cee4792bfda41a02b194d87/artifacts'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from dist_xgboost.data import get_best_model_from_registry\n",
    "\n",
    "best_model, artifacts_dir = get_best_model_from_registry()\n",
    "artifacts_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11f54394",
   "metadata": {},
   "source": [
    "### Production deployment\n",
    "\n",
    "You can wrap the training workload as a production-grade [Anyscale Job](https://docs.anyscale.com/platform/jobs/). See the [API ref](https://docs.anyscale.com/reference/job-api/) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d78b5c0",
   "metadata": {
    "tags": [
     "remove-cell-ci"
    ]
   },
   "outputs": [],
   "source": [
    "from dist_xgboost.constants import root_dir\n",
    "\n",
    "os.environ[\"WORKING_DIR\"] = root_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eae9e135",
   "metadata": {},
   "source": [
    "Then submit the job using the `anyscale` CLI command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdb286b7",
   "metadata": {
    "tags": [
     "remove-cell-ci"
    ],
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Output\n",
      "(anyscale +0.9s) Submitting job with config JobConfig(name='train-xboost-breast-cancer-model', image_uri=None, compute_config=None, env_vars=None, py_modules=None, py_executable=None, cloud=None, project=None, ray_version=None, job_queue_config=None).\n",
      "(anyscale +2.6s) Uploading local dir '/home/ray/default/e2e-xgboost' to cloud storage.\n",
      "(anyscale +3.8s) Including workspace-managed pip dependencies.\n",
      "(anyscale +4.2s) Job 'train-xboost-breast-cancer-model' submitted, ID: 'prodjob_bkbpnmhytt3ljt8ftlnyumjxdj'.\n",
      "(anyscale +4.2s) View the job in the UI: https://console.anyscale.com/jobs/prodjob_bkbpnmhytt3ljt8ftlnyumjxdj\n",
      "(anyscale +4.2s) Use `--wait` to wait for the job to run and stream logs.\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "\n",
    "# Production batch job -- note that this is a bash cell\n",
    "! anyscale job submit --name=train-xboost-breast-cancer-model \\\n",
    "  --containerfile=\"${WORKING_DIR}/containerfile\" \\\n",
    "  --working-dir=\"${WORKING_DIR}\" \\\n",
    "  --exclude=\"\" \\\n",
    "  --max-retries=0 \\\n",
    "  --wait \\\n",
    "  -- cd notebooks && jupyter nbconvert --to script 01-Distributed_Training.ipynb && ipython 01-Distributed_Training.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "557b6050",
   "metadata": {},
   "source": [
    "> - The `containerfile` defines the dependencies, but you can also use a pre-built image\n",
    "> - You can specify compute requirements as a [compute config](https://docs.anyscale.com/configuration/compute-configuration/) or inline in a [job config](https://docs.anyscale.com/reference/job-api#job-cli)\n",
 - When launched from">
    "> - When you launch a job from a workspace without specifying compute, the job defaults to the compute configuration of the workspace"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bda2a7b0",
   "metadata": {},
   "source": [
    "## Scaling strategies\n",
    "\n",
    "One of the key advantages of Ray Train is its ability to effortlessly scale training workloads. By adjusting the [`ScalingConfig`](https://docs.ray.io/en/latest/train/api/doc/ray.train.ScalingConfig.html), you can optimize resource utilization and reduce training time.\n",
    "\n",
    "### Scaling examples\n",
    "\n",
    "**Multi-node CPU example:** 4 nodes with 8 CPUs each\n",
    "\n",
    "```python\n",
    "scaling_config = ScalingConfig(\n",
    "    num_workers=4,\n",
    "    resources_per_worker={\"CPU\": 8},\n",
    ")\n",
    "```\n",
    "\n",
    "**Single-node multi-GPU example:** 1 node with 8 CPUs and 4 GPUs\n",
    "\n",
    "```python\n",
    "scaling_config = ScalingConfig(\n",
    "    num_workers=4,\n",
    "    use_gpu=True,\n",
    ")\n",
    "```\n",
    "\n",
    "**Multi-node multi-GPU example:** 4 nodes with 8 CPUs and 4 GPUs each\n",
    "\n",
    "```python\n",
    "scaling_config = ScalingConfig(\n",
    "    num_workers=16,\n",
    "    use_gpu=True,\n",
    ")\n",
    "```\n",
    "\n",
    "> **Important:** For multi-node clusters, you must specify a shared storage location, such as cloud storage or NFS, in the `run_config`. Using a local path raises an error during checkpointing.\n",
    ">\n",
    "> ```python\n",
    "> trainer = XGBoostTrainer(\n",
    ">     ..., run_config=ray.train.RunConfig(storage_path=\"s3://...\")\n",
    "> )\n",
    "> ```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdab5180",
   "metadata": {},
   "source": [
    "### Worker configuration guidelines\n",
    "\n",
    "The optimal number of workers depends on the workload and cluster setup:\n",
    "\n",
    "- For **CPU-only training**, generally use one worker per node. XGBoost can leverage multiple CPUs with threading.\n",
    "- For **multi-GPU training**, use one worker per GPU.\n",
    "- For **heterogeneous clusters**, consider the greatest common divisor of CPU counts.\n",
    "\n",
    "### GPU acceleration\n",
    "\n",
    "To use GPUs for training:\n",
    "\n",
    "1. Start one actor per GPU with `use_gpu=True`\n",
    "2. Set GPU-compatible parameters, for example, `tree_method=\"gpu_hist\"` for XGBoost\n",
    "3. Divide CPUs evenly across actors on each machine\n",
    "\n",
    "#### Example:\n",
    "\n",
    "```python\n",
    "trainer = XGBoostTrainer(\n",
    "    scaling_config=ScalingConfig(\n",
    "        # Number of workers to use for data parallelism.\n",
    "        num_workers=2,\n",
    "        # Whether to use GPU acceleration.\n",
    "        use_gpu=True,\n",
    "    ),\n",
    "    params={\n",
    "        # XGBoost specific params.\n",
    "        \"tree_method\": \"gpu_hist\",  # GPU-specific parameter\n",
    "        \"eval_metric\": [\"logloss\", \"error\"],\n",
    "    },\n",
    "    ...\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "021456cc",
   "metadata": {},
   "source": [
    "For more advanced topics, see:\n",
    "- [Ray Tune](https://docs.ray.io/en/latest/tune/index.html) for hyperparameter optimization\n",
    "- [Ray Serve](https://docs.ray.io/en/latest/serve/index.html) for model deployment\n",
    "- [Ray Data](https://docs.ray.io/en/latest/data/data.html) for more advanced data processing"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
